{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "functional-corrections",
   "metadata": {},
   "source": [
    "## K-means with multi-dimensional data\n",
    " \n",
    "$X_{n \\times d}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "formal-antique",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "durable-horse",
   "metadata": {},
   "outputs": [],
   "source": [
    "n, d, k=1000, 20, 4\n",
    "max_itr=100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "egyptian-omaha",
   "metadata": {},
   "outputs": [],
   "source": [
    "X=np.random.random((n,d))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "employed-helen",
   "metadata": {},
   "source": [
    "$$ argmin_j  ||x_i - c_j||_2 $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "center-timer",
   "metadata": {},
   "outputs": [],
   "source": [
    "def k_means(X, k):\n",
    "    #Randomly Initialize Centroids\n",
    "    np.random.seed(0)\n",
    "    C= X[np.random.randint(n,size=k),:]\n",
    "    E=np.float('inf')\n",
    "    for itr in range(max_itr):\n",
    "        \n",
    "        # Find the distance of each point from the centroids \n",
    "        E_prev=E\n",
    "        E=0\n",
    "        center_idx=np.zeros(n)\n",
    "        for i in range(n):\n",
    "            min_d=np.float('inf')\n",
    "            c=0\n",
    "            for j in range(k):\n",
    "                d=np.linalg.norm(X[i,:]-C[j,:],2)\n",
    "                if d<min_d:\n",
    "                    min_d=d\n",
    "                    c=j\n",
    "            \n",
    "            E+=min_d\n",
    "            center_idx[i]=c\n",
    "            \n",
    "        #Find the new centers\n",
    "        for j in range(k):\n",
    "            C[j,:]=np.mean( X[center_idx==j,:] ,0)\n",
    "        \n",
    "        if itr%10==0:\n",
    "            print(E)\n",
    "        if E_prev==E:\n",
    "            break\n",
    "            \n",
    "    return C, E, center_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "material-hayes",
   "metadata": {},
   "source": [
    "$$ argmin_j  ||x_i - c_j||_2 $$\n",
    "\n",
    "$$||x_i - c_j||_2 = \\sqrt{(x_i - c_j)^T (x_i-c_j)} = \\sqrt{x_i^T x_i -2 x_i^T c_j + c_j^T c_j} $$\n",
    "\n",
    "- $ diag(X~X^T)$, can be used to get $x_i^T x_i$\n",
    "\n",
    "- $X~C^T $, can be used to get $x_i^T c_j$\n",
    "\n",
    "- $diag(C~C^T)$, can be used to get $c_j^T c_j$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "colored-linux",
   "metadata": {},
   "outputs": [],
   "source": [
    "def k_means_vectorized(X, k):\n",
    "    \n",
    "    #Randomly Initialize Centroids\n",
    "    np.random.seed(0)\n",
    "    C= X[np.random.randint(n,size=k),:]\n",
    "    E=np.float('inf')\n",
    "    for itr in range(max_itr):\n",
    "        # Find the distance of each point from the centroids \n",
    "        XX= np.tile(np.diag(np.matmul(X, X.T)), (k,1) ).T\n",
    "        XC=np.matmul(X, C.T)\n",
    "        CC= np.tile(np.diag(np.matmul(C, C.T)), (n,1)) \n",
    "\n",
    "        D= np.sqrt(XX-2*XC+CC)\n",
    "\n",
    "        # Assign the elements to the centroids:\n",
    "        center_idx=np.argmin(D, axis=1)\n",
    "\n",
    "        #Find the new centers\n",
    "        for j in range(k):\n",
    "            C[j,:]=np.mean( X[center_idx==j,:] ,0)\n",
    "\n",
    "        #Find the error\n",
    "        E_prev=E\n",
    "        E=np.sum(D[np.arange(n),center_idx])\n",
    "        if itr%10==0:\n",
    "            print(E)\n",
    "        if E_prev==E:\n",
    "            break\n",
    "    \n",
    "    return C, E, center_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "equivalent-platinum",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1517.502248752696\n",
      "1218.91004301866\n",
      "1217.362137659097\n",
      "0.8816308975219727 seconds\n"
     ]
    }
   ],
   "source": [
    "start=time.time()\n",
    "C, E, center_idx = k_means(X, k)\n",
    "print(time.time()-start,'seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "environmental-steam",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1517.502248752696\n",
      "1218.9100430186547\n",
      "1217.3621376590977\n",
      "0.09020209312438965 seconds\n"
     ]
    }
   ],
   "source": [
    "start=time.time()\n",
    "C, E, center_idx = k_means_vectorized(X, k)\n",
    "print(time.time()-start,'seconds')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "north-picking",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
